###### Replication file for Jeffrey M. Kaplow. "The Changing Face of Nuclear Proliferation." International Studies Quarterly.
# Produced using R version 3.5.2

### Packages
library(brglm) #v0.6.1
library(sandwich) #v2.5
library(lmtest) #v0.9-36
library(caret) #v6.0-81
library(RColorBrewer) #v1.1-2
library(santoku) #v0.10.0
library(dplyr) #v1.0.9
library(kernlab) #v0.9-27

# Needed to install DMwR 0.4.1
library(devtools)
remotes::install_github("cran/DMwR")
#

library(DMwR) #v0.4.1

### Functions

## robust: cluster-robust (sandwich) standard errors for a fitted model
##
## mod:   fitted model supporting estfun() and sandwich() and carrying $y,
##        $na.action, $residuals and $rank (a brglm fit in this script)
## clust: cluster id for every row of the data the model was fitted on,
##        including rows the fit dropped
## Returns coeftest(mod, V), where V is the cluster-robust covariance scaled by
## the small-sample correction (M / (M - 1)) * ((N - 1) / (N - K)).
robust <- function(mod, clust) {
  # Align the cluster ids with the observations actually used by the fit
  if (length(mod$y) != length(clust)) {
    if (!is.null(mod$na.action)) {
      # Remove the ids of rows dropped by the model's na.action
      clust <- clust[-unique(mod$na.action)]
    } else {
      # No na.action recorded: keep the ids of the rows named in the residuals.
      # NOTE(review): this treats residual names as positions in `clust`, which
      # only holds when the model data had default row names 1..n.
      clust <- clust[as.numeric(names(mod$residuals))]
    }
  }
  M <- length(unique(clust))  # number of clusters
  N <- length(clust)          # number of observations
  K <- mod$rank               # number of estimated coefficients
  dfc <- (M / (M - 1)) * ((N - 1) / (N - K))
  # Sum the score contributions within each cluster, then build the meat
  uj <- apply(estfun(mod), 2, function(x) tapply(x, clust, sum))
  rcv <- dfc * sandwich(mod, meat = crossprod(uj) / N)
  coeftest(mod, rcv)
}

## rwcv: rolling window cross validation (RWCV)
##
## For every pair (i, j) of training and testing window start years in
## styear..(endyear - winlength + 1), fit a radial-kernel SVM (caret method
## "svmRadial") for each formula on the SMOTE-balanced training window
## [i, i + winlength - 1] and compute its positive predictive value (PPV) on the
## testing window [j, j + winlength - 1], using a 0.5 probability threshold.
##
## Arguments:
##   dat       data frame with the DV, the predictors and a "year" column; the DV
##             is converted to a factor and must have the two levels
##             c("no", "yes"), "yes" being the positive class
##   form      list of model formulas given as character strings; the DV is read
##             from the left-hand side of form[[1]]
##   styear    start year of the first window
##   endyear   last year covered by the final window
##   winlength window length in years
##   formnames labels for the models, used in output names (default names(form))
##   smoteform formula used by SMOTE to balance each training window
##
## Returns a list:
##   results  matrix with columns trainyear, testyear and one "<name>_PPV" column
##            per model; NA where PPV is undefined (no observed positives in the
##            test window, or no positive predictions)
##   probmat  one list per model, indexed by run, of two-column matrices holding
##            the predicted probability and the 0/1 true value, for later
##            threshold adjustment (see tmove)
rwcv <- function(dat, form, styear=1950, endyear=2010, winlength=5, formnames=names(form), smoteform=form[[1]]) {
  
  # Identify the DV and year
  dv <- all.vars(formula(form[[1]])[[2]])
  dvcolnum <- which(names(dat)==dv)
  yearcolnum <- which(names(dat)=="year")
  
  # Normalize the predictors (everything except the DV and year) for the SVM
  xvars <- seq(1, ncol(dat), by=1)[-c(dvcolnum, yearcolnum)]
  normalization <- preProcess(dat[,xvars], method=c("center","scale"))
  dat[,xvars] <- predict(normalization, dat[,xvars])
  dat[[dvcolnum]] <- as.factor(dat[[dvcolnum]])
  
  # Results matrix: one row per train-start/test-start pair; columns are the
  # training year, the testing year and one PPV column per model
  winseq <- seq(styear, endyear - (winlength-1))
  winnum <- length(winseq)
  predresults_total <- matrix(data = NA, nrow = winnum^2, ncol = 2 + length(form))
  colnames(predresults_total) <- c("trainyear", "testyear", paste(formnames, "PPV", sep="_"))
  predresults_total[,1] <- rep(winseq, each=winnum)
  predresults_total[,2] <- rep(winseq, times=winnum)
  
  # Keep the raw probabilities of every run for any later threshold adjustment
  predresults_total_prob <- lapply(vector("list", length(formnames)), as.list)
  names(predresults_total_prob) <- formnames
  
  # Create each training dataset and run the SMOTE algorithm
  for(i in styear:max(winseq)){
    traindat <- dat[which(dat[[yearcolnum]] >= i & dat[[yearcolnum]] <= (i + winlength - 1)), ]
    traindat <- traindat[,-yearcolnum]
    traindat <- SMOTE(formula(smoteform), data = traindat, perc.over = 1000, perc.under = 500)
    traindat <- na.omit(traindat)
    rownames(traindat) <- NULL
    
    # Create test data
    for(j in styear:max(winseq)){
      message(paste("Training ", i, ". Testing ", j, ".", sep=""),"\r", appendLF=TRUE)
      flush.console()
      testdat <- dat[which(dat[[yearcolnum]] >= j & dat[[yearcolnum]] <= (j + winlength - 1)), ]
      testdat <- testdat[,-yearcolnum]
      rownames(testdat) <- NULL
      # Look the DV up by name: dropping the year column shifts column positions,
      # so dvcolnum (taken from the full data) is wrong whenever year precedes the DV
      testtruth <- testdat[[dv]]
      
      predresults_per <- rep(NA, length(form))
      predresults_prob <- vector("list", length(form))
      
      for(k in seq_along(form)) {
        trainmod <- train(formula(form[[k]]), data=traindat, method="svmRadial", scale=FALSE, tuneLength = 4, prob.model=TRUE, trControl = trainControl(method="cv", verboseIter = FALSE, returnResamp = "all", classProbs=TRUE, savePredictions="all"))
        
        # Predicted probability of the second (positive) class for the test data
        predvals <- predict(trainmod, newdata=testdat, type="prob")[,2]
        
        # Start every model from the full set of true values, so observations
        # dropped for one model are not silently dropped for the next one
        truevals <- testtruth
        
        # Drop missing predictions for calculating predictive validity
        badobs <- which(is.na(predvals))
        if(length(badobs) > 0) {
          truevals <- truevals[-badobs]
          predvals <- predvals[-badobs]
        }
        
        # Two-column matrix of predicted probability and 0/1 true value
        if(is.factor(truevals)) {
          oldlevs <- levels(truevals)
          levels(truevals) <- c("0","1")
          probmat <- matrix(c(predvals, as.numeric(levels(truevals))[truevals]), nrow=length(predvals), ncol=2)
          levels(truevals) <- oldlevs
        } else {
          probmat <- matrix(c(predvals, truevals), nrow=length(predvals), ncol=2)
        }
        
        ppv <- NA
        
        # Continue only if there are cases of proliferation among the true values
        if((length(which(truevals==1)) > 0) || (length(which(truevals=="yes")) > 0)) {
          
          # Convert predicted probabilities to yes or no
          predvals[which(predvals < 0.5)] <- 0
          predvals[which(predvals >= 0.5)] <- 1
          
          # Continue only if there are "yes" predictions (otherwise PPV is undefined)
          if(length(which(predvals == 1)) > 0) {
            predvals <- as.factor(predvals)
            levels(predvals) <- list(no="0", yes="1")
            
            # Confusion matrix; byClass[3] is the positive predictive value
            out <- confusionMatrix(predvals, as.factor(truevals), positive="yes")
            ppv <- out$byClass[3]
          }
        }
        
        predresults_per[k] <- ppv
        # Store by position so that the read-back below (also by position) works
        # whether or not `form` is named
        predresults_prob[[k]] <- probmat
      }
      
      predresults_total[which(predresults_total[,1]==i & predresults_total[,2]==j), 3:ncol(predresults_total)] <- predresults_per
      # Run index matching the row order of predresults_total
      runnum <- ((which(winseq %in% i) - 1) * length(winseq)) + which(winseq %in% j)
      for(f in seq_along(formnames)){
        predresults_total_prob[[f]][[runnum]] <- predresults_prob[[f]]
      }
    }
  }
  return(list(results=predresults_total, probmat=predresults_total_prob))
}

## tmove: threshold adjustment of RWCV predictions
##
## Recomputes each run's PPV from the stored predicted probabilities, calling a
## prediction positive when its probability is >= `threshold`. When no
## probability in a run reaches the threshold, the run's maximum probability is
## used as the threshold instead, so that there is at least one positive
## prediction.
##
## probmat:   the list returned by rwcv(): [[1]] the results matrix (trainyear,
##            testyear, then one PPV column per model) and [[2]] one list per
##            model, indexed by run, of (probability, 0/1 truth) matrices
## threshold: probability cut-off for a positive prediction
## Returns the results matrix with its PPV columns recomputed; an entry stays
## NA when PPV cannot be computed (e.g. the run has no observed positives).
tmove <- function(probmat, threshold=0.5) {
  
  result <- probmat[[1]]
  result[,-c(1,2)] <- NA
  runprobs <- probmat[[2]]
  
  for(i in seq_len(nrow(result))) {
    # NOTE: i is the row (run) index of the results matrix
    message(paste("Threshold Moving for Year ", i, sep=""),"\r", appendLF=TRUE)
    flush.console()
    # seq_along works whether or not the per-model list is named
    for(j in seq_along(runprobs)) {
      run <- runprobs[[j]][[i]]
      predvals <- run[,1]
      truevals <- run[,2]
      # Nothing to evaluate (and max() of an empty vector is -Inf)
      if(length(predvals) == 0) next
      usethresh <- threshold
      # Lower the threshold so there are positive predictions
      if(!any(predvals >= threshold, na.rm=TRUE)) usethresh <- max(predvals)
      predvals <- as.numeric(predvals >= usethresh)
      # Only proceed if there are "1" values among both the predictions and the truth
      if(any(truevals == 1, na.rm=TRUE) && any(predvals == 1, na.rm=TRUE)) {
        predvals <- as.factor(predvals)
        truevals <- as.factor(truevals)
        levels(predvals) <- list(no="0", yes="1")
        levels(truevals) <- list(no="0", yes="1")
        out <- confusionMatrix(predvals, truevals, positive="yes")
        result[i, j + 2] <- out$byClass[3]  # positive predictive value
      }
    }
  }
  result
}

## barcode: draw a "barcode" grid of a metric over training/testing start years
##
## x, y: training and testing start year of each cell
## val:  value displayed in each cell; binned into groups of equal size
## pal:  RColorBrewer palette name
## main: plot title
## Cells whose training start year is after their testing start year are left
## uncoloured.
barcode <- function(x, y, val, pal, main="") {
  # Bin the values into (up to) 9 groups with equal numbers of observations.
  # Not all levels are necessarily used; the resulting warning can be ignored.
  bins <- chop_equally(val, 9)
  nbins <- length(levels(bins))
  levels(bins) <- seq_len(nbins)

  # One palette colour per bin, looked up by each cell's bin code
  cellcols <- brewer.pal(nbins, pal)[bins]
  cellcols[which(x > y)] <- NA

  # Empty frame, then one unit square per cell centred on (x, y)
  plot(x, y, type="n", xlab="Training Start Year", ylab="Testing Start Year", main=main)
  rect(xleft=x - 0.5, ybottom=y - 0.5, xright=x + 0.5, ytop=y + 0.5, col=cellcols, border=NA)
}


# Load data (provides the `cfdat` data used below)
load("Kaplow_ChangingFaceOfProlif_ISQ.RData")

# Temporal subsamples for regression analysis, limited to states that have not
# yet acquired nuclear weapons. GDP is divided by 100,000 in the subsamples only,
# for more legible regression coefficients. cfdat itself is left untouched:
# dividing its GDP column and multiplying it back would not restore the original
# values exactly in floating point.
nonnuclear_subsample <- function(dat, from, to) {
  out <- dat[which(dat$year >= from & dat$year <= to & dat$hasnuke == 0), ]
  out$realgdp <- out$realgdp / 100000
  out
}
cfdat_all <- nonnuclear_subsample(cfdat, 1950, 2010)
cfdat_early <- nonnuclear_subsample(cfdat, 1950, 1969)
cfdat_middle <- nonnuclear_subsample(cfdat, 1970, 1989)
cfdat_late <- nonnuclear_subsample(cfdat, 1990, 2010)

## Table 1: Regression analysis of the drivers of proliferation by time period
# Each model is a bias-reduced logit (brglm) reported with standard errors
# clustered on ccode via robust(), followed by its number of observations.

# Full specification, used by Models 1, 3 and 4
form_table1 <- hasprog ~ tcfc + n_cap7_ext + realgdp + nukeprogriv + acdlast5 +
  nukepact + NPT + regime_emb + nonprogyears + nonprogyears2 + nonprogyears3 +
  progyears + progyears2 + progyears3
# Model 2 drops tcfc, NPT and regime_emb
form_table1_early <- hasprog ~ n_cap7_ext + realgdp + nukeprogriv + acdlast5 +
  nukepact + nonprogyears + nonprogyears2 + nonprogyears3 +
  progyears + progyears2 + progyears3
logit_family <- binomial(link = "logit")

# Model 1: 1950-2010
mod_all <- brglm(form_table1, family = logit_family, data = cfdat_all)
robust(mod_all, cfdat_all$ccode)
nrow(mod_all$model)

# Model 2: 1950-1969
mod_early <- brglm(form_table1_early, family = logit_family, data = cfdat_early)
robust(mod_early, cfdat_early$ccode)
nrow(mod_early$model)

# Model 3: 1970-1989
mod_middle <- brglm(form_table1, family = logit_family, data = cfdat_middle)
robust(mod_middle, cfdat_middle$ccode)
nrow(mod_middle$model)

# Model 4: 1990-2010
mod_late <- brglm(form_table1, family = logit_family, data = cfdat_late)
robust(mod_late, cfdat_late$ccode)
nrow(mod_late$model)

## Rolling Window Cross Validation (RWCV)

# Formulas (character strings, as rwcv expects) for the models being compared
formfull <- list(
  full = "hasprog ~ tcfc + n_cap7_ext + realgdp + nukeprogriv + acdlast5 + nukepact + NPT + regime_emb",
  supply = "hasprog ~ tcfc + n_cap7_ext + realgdp",
  demand = "hasprog ~ nukeprogriv + acdlast5 + nukepact"
)
forminst <- list(inst = "hasprog ~ NPT + regime_emb")

# Restrict to the full model's variables plus year, keeping complete rows only
varlist <- c(all.vars(formula(formfull$full)), "year")
preddat <- na.omit(cfdat_all[, varlist])

# Categorical predictors become factors; the DV becomes a no/yes factor
factorvars <- c("nukepact", "NPT", "acdlast5", "nukeprogriv")
preddat[factorvars] <- lapply(preddat[factorvars], as.factor)
preddat$hasprog <- as.factor(preddat$hasprog)
levels(preddat$hasprog) <- list(no = "0", yes = "1")
row.names(preddat) <- NULL

# Run the RWCV procedure (each run is seeded for reproducibility)
####  NOTE: The next four lines take about 13 hours to run. The output from these lines is included in the replication dataset.
set.seed(999)
predout_full <- rwcv(dat = preddat, form = formfull, styear = 1950, endyear = 2010, winlength = 5)
set.seed(999)
predout_inst <- rwcv(dat = preddat, form = forminst, styear = 1957, endyear = 2010, winlength = 5)

# Resume here if using the RWCV output provided with the replication data
# Implement the threshold adjustment procedure on the stored probabilities
res_full <- tmove(predout_full)
res_inst <- tmove(predout_inst)

## Figure 1: PPV of the full model for every training/testing window pair
barcode(res_full[, "trainyear"], res_full[, "testyear"], res_full[, "full_PPV"], pal="Blues")

## Figure 2
# Join explicitly on the two window start years. The institutions RWCV starts
# in 1957, so its PPV is NA for the earlier cells. Columns are then picked by
# name rather than by position.
fig_out <- left_join(as.data.frame(res_full), as.data.frame(res_inst), by = c("trainyear", "testyear"))
par(mfrow=c(3,1), mar=c(3,3,2,1))
barcode(fig_out[["trainyear"]], fig_out[["testyear"]], fig_out[["supply_PPV"]], pal="Reds", main="Nuclear capability")
barcode(fig_out[["trainyear"]], fig_out[["testyear"]], fig_out[["demand_PPV"]], pal="Oranges", main="Nuclear motivation")
barcode(fig_out[["trainyear"]], fig_out[["testyear"]], fig_out[["inst_PPV"]], pal="Greens", main="International institutions")
# Only the panel layout is reset; the margins set above remain in effect for Figure 3
par(mfrow=c(1,1))

## Figure 3: mean PPV across all training windows, by testing start year
figsum <- fig_out %>% 
  group_by(testyear) %>% 
  summarize(avesupplyppv = mean(supply_PPV, na.rm=TRUE), avedemppv = mean(demand_PPV, na.rm=TRUE), aveinstppv = mean(inst_PPV, na.rm=TRUE))

plot(figsum$testyear, figsum$avesupplyppv, type="l", lwd=2, col="red",ylim=c(0,.4), cex.lab=0.7, cex.axis=0.7, ylab="Mean PPV for all Training Data", xlab="Testing Start Year")
lines(figsum$testyear, figsum$avedemppv, lwd=2, lty=2, col="orange")
lines(figsum$testyear, figsum$aveinstppv, lwd=2, lty=3, col="darkgreen")
legend(x="topright", legend=c("Supply", "Demand", "Institutions"), col=c("red", "orange", "darkgreen"), y.intersp=1, seg.len = 2.5,  lty=c(1,2,3), lwd=2, border=NA, cex=.6, bty="n")


